#!/usr/bin/env python

# test-issue-url: https://github.com/rethinkdb/rethinkdb/issues/1774
# test-description: reads and writes can still proceed while a table is sharded

import os, sys, time

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir, 'common')))
import driver, utils

r = utils.import_python_driver()

# -- setup

dbName, tableName = utils.get_test_db_table()

# server handle from the test driver, plus the connection used for the
# administrative queries in the rest of this script
server = driver.Process(console_output=True)
conn = r.connect(host=server.host, port=server.driver_port)

# make sure we have a clean table with no shards

existingDatabases = r.db_list().run(conn)
if dbName not in existingDatabases:
    r.db_create(dbName).run(conn)

# drop any leftover table of the same name, then create it fresh
existingTables = r.db(dbName).table_list().run(conn)
if tableName in existingTables:
    r.db(dbName).table_drop(tableName).run(conn)
r.db(dbName).table_create(tableName).run(conn)

# -- thread methods

class FillTable(utils.PerformContinuousAction):
    """Action that inserts one sequentially numbered record per invocation.

    Recognised keyword argument (read from self.kwargs):
      tableName -- table to insert into; defaults to this script's base name
    """

    def runAction(self):
        """Insert the record whose id and data are both successCount + 1.

        The insert uses hard durability and conflict='error', so re-using an
        existing id is reported as an error rather than overwriting the row.
        """
        scriptBaseName = os.path.splitext(os.path.basename(__file__))[0]
        targetTable = self.kwargs['tableName'] if 'tableName' in self.kwargs else scriptBaseName

        record = {'id': self.successCount + 1, 'data': self.successCount + 1}
        insertQuery = r.db(dbName).table(targetTable).insert(
            record, durability='hard', conflict='error'
        )
        insertQuery.run(self.connection)

class WaitForTable(utils.PerformContinuousAction):
    """Action that polls a table until a read succeeds or a timeout elapses.

    Recognised keyword arguments (read from self.kwargs):
      tableName -- table to read; defaults to this script's base name
      timeout   -- seconds to keep retrying; defaults to 10
    """

    def runAction(self):
        """Retry a one-row read until it succeeds, fails hard, or times out.

        r.ReqlRuntimeError is treated as "not readable yet" and retried. Any
        other exception is recorded with recordError and ends the attempt. A
        timeout is recorded only when the deadline passes without the loop
        having been left by a success or by a recorded error.
        """
        tableName = os.path.splitext(os.path.basename(__file__))[0]
        if 'tableName' in self.kwargs:
            tableName = self.kwargs['tableName']
        timeout = 10
        if 'timeout' in self.kwargs:
            timeout = self.kwargs['timeout']
        deadline = time.time() + timeout

        while time.time() < deadline:
            try:
                r.db(dbName).table(tableName).limit(1).run(self.connection)
                break
            except r.ReqlRuntimeError:
                # the table is not readable yet; try again until the deadline
                pass
            except Exception as e:
                self.recordError(e)
                break
        else:
            # Only reached when the loop condition failed, i.e. the deadline
            # passed. Re-reading the clock after the loop could also flag a
            # read that succeeded (or an error already recorded) just as the
            # deadline arrived.
            self.recordError('Timed out after waiting %s seconds for table %s to be read' % (timeout, tableName))

# -- run the test

# - start filling the table

fillTableProcess = FillTable(connection=r.connect(host=server.host, port=server.driver_port), tableName=tableName)
time.sleep(3) # give it a moment to have something

# - start reading the table

readTableProcess = utils.PerformContinuousAction(connection=r.connect(host=server.host, port=server.driver_port), action=r.db(dbName).table(tableName).sample(1))

# - perform the shard

shardStartTime = time.time()
tableId = r.db(dbName).table(tableName).config()['id'].run(conn)
shardsArray = [{'primary_replica': server.name, 'replicas': [server.name]} for _ in range(4)]
# Run the reconfiguration as its own statement: if it lived inside the assert
# expression it would be skipped entirely when Python runs with -O.
shardResult = r.db('rethinkdb').table('table_config').get(tableId).update({'shards': shardsArray}).run(conn)
assert shardResult['errors'] == 0, 'Unable to reshard the table: %r' % (shardResult,)

# -- wait for the table to be ready

# Overall result flag. It is initialised before the first check that can clear
# it, so a failure while waiting for the table is not reset further down.
allPassed = True

waitTableTimeout = 10
# pass tableName explicitly so the waiter reads the table that was just resharded
waitTableProcess = WaitForTable(connection=r.connect(host=server.host, port=server.driver_port), database=dbName, tableName=tableName, timeout=waitTableTimeout)
waitTableProcess.join(timeout=waitTableTimeout)
if waitTableProcess.errorCount != 0:
    allPassed = False
    for errorMessage in waitTableProcess.recordedErrors:
        sys.stderr.write('Failure waiting for table to shard: %s\n' % errorMessage)
else:
    print('Success: %.2f seconds after shard command the table was ready' % ((waitTableProcess.startTime - shardStartTime) + waitTableProcess.duration))

# -- wind down the processes

time.sleep(1) # another second to add more data

fillTableProcess.stop()
readTableProcess.stop()

# -- report on errors

print('Created %d records, %d read, in %.2f seconds' % (fillTableProcess.successCount, readTableProcess.successCount, fillTableProcess.duration))

# - fill, then read: every summarised error fails the test and is echoed to stderr

for continuousProcess, failureFormat in (
    (fillTableProcess, 'Failure while writing: %s x %s\n'),
    (readTableProcess, 'Failure while reading: %s x %s\n'),
):
    for errorMessage, errorCount in continuousProcess.errorSummary().items():
        allPassed = False
        sys.stderr.write(failureFormat % (errorCount, errorMessage))

# - connection is still valid

# A trivial query on the original connection shows whether it survived the
# reshard. A driver error counts as a test failure.
try:
    r.db_list().run(conn)
    print('Success: conection is still valid')
except r.errors.ReqlDriverError:
    allPassed = False
    sys.stderr.write('Failure: The database connnection went stale\n')
    # keep the replacement connection so later queries do not reuse the stale one
    conn = r.connect(host=server.host, port=server.driver_port)

# - everything made it to the database that we expected

# Compare what the table holds against what the writer had acknowledged:
# first the row count, then the sum of the 'data' field.
try:
    actualRecordCount = r.db(dbName).table(tableName).count().run(conn)
    if actualRecordCount != fillTableProcess.successCount:
        allPassed = False
        sys.stderr.write('Failure: The count of records in the table (%d) does not match the number we added (%d)\n' % (actualRecordCount, fillTableProcess.successCount))
    else:
        print('Success: correct number of acknowledged records: %d' % actualRecordCount)

    # The writer stores data = 1..successCount, so the expected total is the
    # triangular number n(n+1)/2. Floor division keeps it an exact integer on
    # both Python 2 and Python 3.
    expectedSum = (fillTableProcess.successCount * (fillTableProcess.successCount + 1)) // 2
    actualSum = r.db(dbName).table(tableName).sum('data').run(conn)
    if expectedSum != actualSum:
        allPassed = False
        sys.stderr.write('Failure: The sum of the data in the table (%d) does not match the number we expected (%d)\n' % (actualSum, expectedSum))
    else:
        print('Success: correct sum of acknowledged records: %d' % expectedSum)
except Exception as e:
    # Any error while reviewing the data is a test failure; report it and
    # carry on to the shutdown step straight away.
    allPassed = False
    sys.stderr.write('Failure: unable to connect to the table to review data: %s %s\n' % (e.__class__.__name__, str(e)))

# -- wind down the server

server.check_and_stop()

# -- final verdict: a non-zero exit with a message on failure

if allPassed is not False:
    print('Test passed')
else:
    sys.exit('Test failed!')
